#include "FloatFormatEncoding.hpp"

#include <cctype>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <ios>
#include <locale>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>

#include <ystdlib/error_handling/Result.hpp>

namespace clp_s {
namespace {
// Returns whether the exponent-sign bits of `format` equal `sign_flag`.
auto has_matching_exponent_sign_flag(float_format_t format, float_format_t sign_flag) -> bool;
// Returns whether the scientific-notation-enabled bit is set in `format`.
auto has_scientific_notation(float_format_t format) -> bool;
// Returns whether `format` records an uppercase 'E' as the exponent marker.
auto is_uppercase_exponent(float_format_t format) -> bool;

// Decode the digit counts packed into `format`; each field stores the count minus one.
auto get_num_exponent_digits(float_format_t format) -> size_t;
auto get_num_significant_digits(float_format_t format) -> size_t;

/**
 * Trims leading zeros from the exponent so that it approaches `num_exp_digits` digits.
 *
 * Consecutive '0' characters starting at `start` are removed. Removal never leaves fewer than
 * `num_exp_digits` characters after `start`. It also stops at the first non-zero digit, so the
 * exponent's value is preserved. If the exponent cannot be shortened that far, it is left longer
 * than `num_exp_digits`.
 *
 * @param scientific_notation The scientific notation string generated by std::scientific.
 * @param start The position of the first exponent digit: right after the exponent marker (E or
 * e), or one character later if the exponent has a sign.
 * @param num_exp_digits The number of exponent digits stored in the format.
 * @return A copy of `scientific_notation` with the exponent's leading zeros trimmed, or an
 * unchanged copy if `start` is past the end of the string.
 */
auto trim_leading_zeros(std::string_view scientific_notation, size_t start, size_t num_exp_digits)
        -> std::string;

/**
 * Converts a scientific notation string into plain decimal notation (no exponent), keeping every
 * mantissa digit, e.g. "1.50e+02" -> "150" and "-1.5e-02" -> "-0.015".
 *
 * @param scientific_notation The scientific notation string generated by std::scientific.
 * @return The decimal notation string, or `std::errc::protocol_error` if the string has no
 * exponent marker (E or e) followed by an exponent.
 */
auto scientific_to_decimal(std::string_view scientific_notation)
        -> ystdlib::error_handling::Result<std::string>;

auto has_matching_exponent_sign_flag(float_format_t format, float_format_t sign_flag) -> bool {
    // Isolate the exponent-sign bits before comparing them with the requested flag.
    auto const sign_bits{format & cExponentSignFlagMask};
    return sign_bits == sign_flag;
}

auto has_scientific_notation(float_format_t format) -> bool {
    // Non-zero exactly when the scientific-notation-enabled bit is set.
    return (format & cScientificNotationEnabledBit) != 0U;
}

auto is_uppercase_exponent(float_format_t format) -> bool {
    // Compare the whole notation field, not a single bit, against the uppercase-'E' flag.
    auto const notation_bits{format & cScientificNotationFlagMask};
    return notation_bits == cScientificNotationUpperCaseEFlag;
}

auto get_num_exponent_digits(float_format_t format) -> size_t {
    // The field holds the exponent digit count minus one.
    auto const stored_count{(format & cNumExponentDigitsMask) >> cNumExponentDigitsPos};
    return static_cast<size_t>(stored_count) + size_t{1};
}

auto get_num_significant_digits(float_format_t format) -> size_t {
    // The field holds the significant digit count minus one.
    auto const stored_count{(format & cNumSignificantDigitsMask) >> cNumSignificantDigitsPos};
    return static_cast<size_t>(stored_count) + size_t{1};
}

auto trim_leading_zeros(std::string_view scientific_notation, size_t start, size_t num_exp_digits)
        -> std::string {
    std::string trimmed(scientific_notation);
    auto const length{trimmed.length()};
    if (length <= start) {
        return trimmed;
    }

    // Zeros at or beyond `stop` are kept so at least `num_exp_digits` characters remain after
    // `start`; a non-zero digit ends the scan early so the exponent value is unchanged.
    size_t const stop{num_exp_digits > length ? size_t{0} : length - num_exp_digits};
    size_t zeros_end{start};
    while (zeros_end < stop && '0' == trimmed[zeros_end]) {
        ++zeros_end;
    }
    trimmed.erase(start, zeros_end - start);
    return trimmed;
}

/**
 * Converts a scientific notation string (as produced by std::scientific) into plain decimal
 * notation, keeping every mantissa digit, e.g. "1.50e+02" -> "150", "-1.5e-02" -> "-0.015".
 *
 * @return The decimal string, or `std::errc::protocol_error` if there is no exponent marker,
 * the exponent is not a valid integer, or the mantissa has no digits.
 */
auto scientific_to_decimal(std::string_view scientific_notation)
        -> ystdlib::error_handling::Result<std::string> {
    // Detach an explicit minus sign; it is prepended again to the final result. Only '-' is
    // treated as a sign, so other leading characters are never silently dropped.
    bool const is_negative{
            false == scientific_notation.empty() && '-' == scientific_notation.front()
    };
    if (is_negative) {
        scientific_notation.remove_prefix(1);
    }

    auto const exp_pos{scientific_notation.find_first_of("Ee")};
    if (std::string_view::npos == exp_pos || exp_pos + 1 >= scientific_notation.length()) {
        return std::errc::protocol_error;
    }

    // Split into mantissa and exponent parts
    auto const mantissa{scientific_notation.substr(0, exp_pos)};
    auto exponent_str{scientific_notation.substr(exp_pos + 1)};

    // Parse the exponent without exceptions. `std::from_chars` rejects a leading '+', which
    // std::scientific emits for non-negative exponents, so strip it first.
    if ('+' == exponent_str.front()) {
        exponent_str.remove_prefix(1);
    }
    int exponent{};
    auto const* const exponent_end{exponent_str.data() + exponent_str.size()};
    auto const [parsed_end, parse_ec]
            = std::from_chars(exponent_str.data(), exponent_end, exponent);
    if (std::errc{} != parse_ec || exponent_end != parsed_end) {
        return std::errc::protocol_error;
    }

    // Remove the decimal point from the mantissa
    auto const dot_pos{mantissa.find('.')};
    std::string digits(mantissa.substr(0, dot_pos));
    if (std::string_view::npos != dot_pos) {
        digits += mantissa.substr(dot_pos + 1);
    }
    if (digits.empty()) {
        return std::errc::protocol_error;
    }

    // Position of the decimal point within `digits` after applying the exponent. Without a '.',
    // every mantissa digit is an integer digit. 64-bit arithmetic avoids `int` overflow.
    auto const num_digits{static_cast<int64_t>(digits.size())};
    auto const num_integer_digits{
            std::string_view::npos == dot_pos ? num_digits : static_cast<int64_t>(dot_pos)
    };
    int64_t const decimal_pos{num_integer_digits + exponent};

    std::string result;
    if (is_negative) {
        result += '-';
    }
    if (decimal_pos <= 0) {
        result += "0.";
        result.append(static_cast<size_t>(-decimal_pos), '0');
        result += digits;
    } else if (decimal_pos < num_digits) {
        std::string_view const all_digits{digits};
        auto const split_pos{static_cast<size_t>(decimal_pos)};
        result += all_digits.substr(0, split_pos);
        result += '.';
        result += all_digits.substr(split_pos);
    } else {
        result += digits;
        result.append(static_cast<size_t>(decimal_pos - num_digits), '0');
    }

    return result;
}
}  // namespace

/**
 * Computes the format information describing how `float_str` is written:
 * - whether it uses scientific notation, and with 'E' or 'e';
 * - the exponent's sign style and digit count;
 * - the number of significant digits.
 *
 * Accepted input has the following parts:
 * - an optional leading '-';
 * - an integer part made of digits;
 * - an optional '.' followed by at least one digit;
 * - an optional exponent: 'E' or 'e', an optional '+' or '-', and 1-4 digits.
 *
 * Scientific numbers must have exactly one integer digit. That digit must be non-zero unless
 * every mantissa and exponent digit is zero. Integer parts with redundant leading zeros ("0N.Y")
 * are rejected.
 *
 * @param float_str The textual representation of the floating point number.
 * @return The encoded format, or `std::errc::protocol_not_supported` if `float_str` is empty,
 * malformed, or uses a layout that cannot be encoded.
 */
auto get_float_encoding(std::string_view float_str)
        -> ystdlib::error_handling::Result<float_format_t> {
    if (float_str.empty()) {
        return std::errc::protocol_not_supported;
    }

    auto const dot_pos{float_str.find('.')};
    float_format_t format{};

    // Only a leading '-' is accepted as a sign, and it must be followed by a digit. This rejects
    // an explicit '+' as well as inputs such as ".", ".5", or "E5".
    size_t const first_digit_pos{'-' == float_str[0] ? 1ULL : 0ULL};
    if (float_str.size() <= first_digit_pos
        || false == std::isdigit(static_cast<unsigned char>(float_str[first_digit_pos])))
    {
        return std::errc::protocol_not_supported;
    }

    // Check whether it is scientific; if so, whether the exponent is E or e
    size_t exp_pos{float_str.find_first_of("Ee")};
    if (std::string_view::npos != exp_pos) {
        // For scientific numbers we only accept one digit before the decimal point, or before
        // the exponent when there is no decimal point (e.g., "12E3" and "12.5E3" are illegal).
        auto const integer_part_end{std::string_view::npos == dot_pos ? exp_pos : dot_pos};
        if ((first_digit_pos + 1ULL) != integer_part_end) {
            return std::errc::protocol_not_supported;
        }

        // For scientific numbers we only accept non-zero first digits, unless all digits are 0.
        auto const is_zero{'0' == float_str[first_digit_pos]};
        if (is_zero && std::string_view::npos != dot_pos) {
            for (size_t i{dot_pos + 1}; i < exp_pos; ++i) {
                if ('0' != float_str[i]) {
                    return std::errc::protocol_not_supported;
                }
            }
        }

        // Exponent must be followed by an integer (e.g., "1E" or "1e+" are illegal)
        if (false
            == ((exp_pos + 1 < float_str.length()
                 && std::isdigit(static_cast<unsigned char>(float_str[exp_pos + 1])))
                || (exp_pos + 2 < float_str.length()
                    && ('+' == float_str[exp_pos + 1] || '-' == float_str[exp_pos + 1])
                    && std::isdigit(static_cast<unsigned char>(float_str[exp_pos + 2])))))
        {
            return std::errc::protocol_not_supported;
        }

        format |= 'E' == float_str[exp_pos] ? cScientificNotationUpperCaseEFlag
                                            : cScientificNotationLowerCaseEFlag;

        // Check whether there is a sign for the exponent
        auto const exp_sign{float_str[exp_pos + 1]};
        bool const has_exp_sign{'+' == exp_sign || '-' == exp_sign};
        if ('+' == exp_sign) {
            format |= cPlusExponentSignFlag;
        } else if ('-' == exp_sign) {
            format |= cMinusExponentSignFlag;
        }

        // Set the number of exponent digits (the optional sign is not counted)
        size_t const exp_digits_pos{has_exp_sign ? exp_pos + 2 : exp_pos + 1};
        size_t const num_exp_digits{float_str.length() - exp_digits_pos};
        if (0ULL == num_exp_digits || num_exp_digits > 4) {
            return std::errc::protocol_not_supported;
        }

        // Everything after the optional sign must be an exponent digit (e.g., "1e5x" is illegal).
        // If the number is a zero, all of the exponent digits must also be zero.
        for (size_t i{exp_digits_pos}; i < float_str.length(); ++i) {
            auto const c{float_str[i]};
            if (false == std::isdigit(static_cast<unsigned char>(c)) || (is_zero && '0' != c)) {
                return std::errc::protocol_not_supported;
            }
        }

        format |= static_cast<float_format_t>(num_exp_digits - 1) << cNumExponentDigitsPos;
    } else {
        exp_pos = float_str.length();
    }

    // The mantissa may only contain digits and a single '.', which must be followed by at least
    // one digit (e.g., "1.2.3", "1a", "1." and "1.e5" are illegal).
    for (size_t i{first_digit_pos}; i < exp_pos; ++i) {
        if (dot_pos != i && false == std::isdigit(static_cast<unsigned char>(float_str[i]))) {
            return std::errc::protocol_not_supported;
        }
    }
    if (dot_pos < exp_pos && dot_pos + 1 == exp_pos) {
        return std::errc::protocol_not_supported;
    }

    // Find first non-zero digit position
    size_t first_non_zero_frac_digit_pos{first_digit_pos};
    if ('0' == float_str[first_non_zero_frac_digit_pos]) {
        // We don't support prefix zeroes of the form 0N.Y
        if (first_non_zero_frac_digit_pos + 1 < float_str.length()
            && std::isdigit(
                    static_cast<unsigned char>(float_str[first_non_zero_frac_digit_pos + 1])
            ))
        {
            return std::errc::protocol_not_supported;
        }

        // For "0.xxx", find the first non-zero digit after the decimal
        if (std::string_view::npos != dot_pos) {
            for (size_t i{dot_pos + 1}; i < exp_pos; ++i) {
                if ('0' != float_str[i]) {
                    first_non_zero_frac_digit_pos = i;
                    break;
                }
            }
        }
    }

    // Digits from the first significant one up to the exponent, minus the '.' if it lies inside.
    auto num_significant_digits{exp_pos - first_non_zero_frac_digit_pos};
    if (std::string_view::npos != dot_pos && first_non_zero_frac_digit_pos < dot_pos) {
        num_significant_digits--;
    }

    // The number of significant digits must be in [1, cMaxNumSignificantDigits].
    if (0ULL == num_significant_digits || num_significant_digits > cMaxNumSignificantDigits) {
        return std::errc::protocol_not_supported;
    }

    float_format_t const compressed_num_significant_digits{
            static_cast<float_format_t>(num_significant_digits - 1ULL)
    };
    format |= compressed_num_significant_digits << cNumSignificantDigitsPos;
    return format;
}

/**
 * Prints `value` using the layout recorded in `format`.
 *
 * Without the scientific-notation flag, the value is printed in plain decimal notation with the
 * stored number of significant digits. Otherwise it is printed in scientific notation using the
 * stored exponent letter case and exponent sign style. The exponent is then padded, or trimmed
 * of leading zeros, toward the stored exponent digit count.
 *
 * @return The formatted string, or `std::errc::protocol_error` if the printed value contains
 * no exponent.
 */
auto restore_encoded_float(double value, float_format_t format)
        -> ystdlib::error_handling::Result<std::string> {
    auto const num_significant_digits{get_num_significant_digits(format)};

    std::ostringstream stream;
    // The classic locale guarantees '.' as the decimal point.
    stream.imbue(std::locale::classic());
    // std::scientific prints one integer digit plus `precision` fractional digits.
    stream << std::scientific << std::setprecision(static_cast<int>(num_significant_digits) - 1);

    // Plain decimal output is derived from the scientific rendering.
    if (false == has_scientific_notation(format)) {
        stream << value;
        return scientific_to_decimal(stream.str());
    }

    if (is_uppercase_exponent(format)) {
        stream << std::uppercase;
    }
    stream << value;
    std::string restored{stream.str()};

    auto const exp_pos{restored.find_first_of("Ee")};
    if (std::string::npos == exp_pos || restored.length() <= exp_pos + 1) {
        return std::errc::protocol_error;
    }
    auto const sign_pos{exp_pos + 1};
    auto const sign_char{static_cast<unsigned char>(restored[sign_pos])};
    auto const num_exp_digits{get_num_exponent_digits(format)};

    // Pads or trims the exponent digits starting at `digits_pos` toward `num_exp_digits`.
    auto const fit_exponent_digits = [&](size_t digits_pos) {
        auto const current_num_digits{restored.length() - digits_pos};
        if (num_exp_digits < current_num_digits) {
            restored = trim_leading_zeros(restored, digits_pos, num_exp_digits);
        } else {
            restored.insert(digits_pos, num_exp_digits - current_num_digits, '0');
        }
    };

    if (has_matching_exponent_sign_flag(format, cEmptyExponentSignFlag)) {
        // The original exponent had no sign, so drop the one emitted by the stream.
        if ('+' == sign_char || '-' == sign_char) {
            restored.erase(sign_pos, 1);
        }
        fit_exponent_digits(sign_pos);
        return restored;
    }

    char required_sign{'\0'};
    if (has_matching_exponent_sign_flag(format, cPlusExponentSignFlag)) {
        required_sign = '+';
    } else if (has_matching_exponent_sign_flag(format, cMinusExponentSignFlag)) {
        required_sign = '-';
    }
    if ('\0' != required_sign) {
        if (static_cast<bool>(std::isdigit(sign_char))) {
            restored.insert(sign_pos, 1, required_sign);
        } else {
            restored[sign_pos] = required_sign;
        }
    }

    fit_exponent_digits(sign_pos + 1);
    return restored;
}
}  // namespace clp_s
